#include <dlib/image_processing/frontal_face_detector.h>
#include <dlib/image_processing.h>
#include <dlib/image_io.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <exception>
#include <filesystem>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "inc.h"

using namespace dlib;

/// Detects, aligns and saves the first face found in one image.
///
/// `image.first` is the path of the image to load. `image.second` is its name,
/// expected to look like "X_Y_i". After '_' is replaced by a space it must split
/// into exactly three whitespace-separated tokens. The aligned 224x224 face chip
/// is written as JPEG to `base_path/X/Y/<image.second>.jpg`, and missing
/// directories are created.
///
/// @param detector  face detector to run (not owned, must be non-null)
/// @param sp        landmark predictor used for alignment (not owned, non-null)
/// @param base_path root output directory
/// @param image     (input path, name) pair
/// @return the path of the saved JPEG, or "" when the name does not have
///         three tokens or no face was detected.
/// @throws exceptions from dlib image loading/saving or std::filesystem
///         (e.g. unreadable input, unwritable output directory).
std::string work(
	frontal_face_detector* detector,
	shape_predictor* sp,
	const std::string& base_path,
	const image_t& image)
{
	constexpr unsigned long chip_size = 224;
	constexpr double chip_padding = 0.2;

	// Validate the name first. It is cheap, and rejecting it up front avoids
	// loading the image and running detection for a result that would be
	// discarded anyway.
	std::string name = image.second;
	std::replace(name.begin(), name.end(), '_', ' ');
	std::istringstream tokens(name);
	std::vector<std::string> parts;
	for (std::string token; tokens >> token;)
		parts.push_back(token);
	if (parts.size() != 3) return "";

	array2d<rgb_pixel> img;
	load_image(img, image.first);
	// Upsample once. This lets the detector find smaller faces.
	pyramid_up(img);

	const std::vector<rectangle> dets = (*detector)(img);
	if (dets.empty()) return "";

	// Only the first detection is used.
	const full_object_detection shape = (*sp)(img, dets[0]);

	// Rotate, scale and crop the face into a chip_size x chip_size image.
	array2d<rgb_pixel> aligned_img;
	extract_image_chip(img, get_face_chip_details(shape, chip_size, chip_padding), aligned_img);

	auto path = std::filesystem::path(base_path) / parts[0] / parts[1]; // X/Y
	std::filesystem::create_directories(path);
	path /= image.second; // X_Y_i (original, underscore form)
	path += ".jpg";

	save_jpeg(aligned_img, path.string());
	return path.string();
}

void worker(
	int id,
	const std::string& base_path,
	const std::vector<image_t>& images,
	std::vector<std::string>& labels)
{
	frontal_face_detector detector = get_frontal_face_detector();
	shape_predictor sp;
	deserialize("assets/shape_predictor_68_face_landmarks.dat") >> sp;

	unsigned count = 0;
	unsigned total = images.size();
	unsigned ten_percent = total / 10;
	ten_percent = ten_percent == 0 ? 1 : ten_percent;
	auto start = std::chrono::steady_clock::now();
	for (auto& image:images)
	{
		try
		{
			count++;
			if (count % ten_percent == 0)
			{
				auto end = std::chrono::steady_clock::now();
				auto elapsed_seconds = std::chrono::duration<double>(end - start);
				std::cout << "[Thread" << id << "] Progress=" << 100 * (double)(count) / total << "% "
						  << "Elapsed Time=" << elapsed_seconds.count() << "s"
						  << std::endl;
			}

			auto path = work(&detector, &sp, base_path, image);
			if (path.empty())continue;
			labels.push_back(path);
		}
		catch (std::exception& e)
		{
			std::cout << "[Thread" << id << "] error: " << e.what() << std::endl;
		}
	}
}

void align(
	const std::string& base_path,
	const std::vector<image_t>& images,
	std::vector<std::string>& labels)
{
	// split into x parts
	unsigned parts_count = 8U;

	std::vector<std::vector<image_t>> sub_images;
	std::vector<std::vector<std::string>> sub_labels(parts_count);
	auto itr = images.begin();
	unsigned full_size = images.size();
	for (unsigned i = 0; i < parts_count; ++i)
	{
		auto part_size = full_size / (parts_count - i);
		full_size -= part_size;
		sub_images.emplace_back(itr, itr + part_size);
		itr += part_size;
	}

	// I don't know how to start a thread dynamically and join them, thus implementing a wait group
	std::thread t0(worker, 0, std::ref(base_path), std::ref(sub_images[0]), std::ref(sub_labels[0]));
	std::thread t1(worker, 1, std::ref(base_path), std::ref(sub_images[1]), std::ref(sub_labels[1]));
	std::thread t2(worker, 2, std::ref(base_path), std::ref(sub_images[2]), std::ref(sub_labels[2]));
	std::thread t3(worker, 3, std::ref(base_path), std::ref(sub_images[3]), std::ref(sub_labels[3]));
	std::thread t4(worker, 4, std::ref(base_path), std::ref(sub_images[4]), std::ref(sub_labels[4]));
	std::thread t5(worker, 5, std::ref(base_path), std::ref(sub_images[5]), std::ref(sub_labels[5]));
	std::thread t6(worker, 6, std::ref(base_path), std::ref(sub_images[6]), std::ref(sub_labels[6]));
	std::thread t7(worker, 7, std::ref(base_path), std::ref(sub_images[7]), std::ref(sub_labels[7]));

	if (t0.joinable())t0.join();
	if (t1.joinable())t1.join();
	if (t2.joinable())t2.join();
	if (t3.joinable())t3.join();
	if (t4.joinable())t4.join();
	if (t5.joinable())t5.join();
	if (t6.joinable())t6.join();
	if (t7.joinable())t7.join();

	// Combine them
	for (unsigned i = 0; i < parts_count; i++)
		labels.insert(labels.end(), sub_labels[i].begin(), sub_labels[i].end());
}